import numpy
import os
import matplotlib.pyplot as pyplot
import scipy.optimize as optimise
from scipy.signal import find_peaks as find_peaks
from scipy import constants
from jqc import jqc_plot
import matplotlib.gridspec as gridspec

# Two-panel layout: a wide panel (a) on the left for the spectroscopy data
# and a narrower panel (b) on the right for the potential curves.
grid = gridspec.GridSpec(1,2,width_ratios=[3,1])

# SI conversion constants from scipy.constants.
bohr = constants.physical_constants['Bohr radius'][0]  # metres
Hartree = constants.physical_constants['atomic unit of energy'][0]  # joules
h = constants.h
autoTHz = 1e-12*Hartree/h  # atomic unit of energy expressed as a frequency in THz

jqc_plot.plot_style("wide")

# Directory containing this script; all input/output paths are built from it.
cwd = os.path.dirname(os.path.abspath(__file__))

fig = pyplot.figure("Spectroscopy")
ax0 = fig.add_subplot(grid[0])

# Twin x axis sharing ax0's y axis; the red (legend "N=1") data are drawn on it.
ax1 = ax0.twiny()

# Fitted line centres and their uncertainties (re-initialised before the
# singlet fit loop).
centres =[]
err_centres=[]
def Gaussian(x,x0,sigma,A):
    """Return an unnormalised Gaussian with peak height ``A`` at ``x0``.

    ``sigma`` is the standard deviation (the one-sigma width). The function
    is element-wise, so ``x`` may be a scalar or a numpy array.
    """
    exponent = -(x-x0)**2/(2*sigma**2)
    return A*numpy.exp(exponent)

def GetAveragedData(Data):
    """Average repeated measurements that share the same x value.

    Parameters
    ----------
    Data : array_like, shape (n, m) with m >= 2
        Column 0 holds the x value and column 1 the measured y value. Any
        further columns are ignored.

    Returns
    -------
    numpy.ndarray, shape (k, 3)
        One row per distinct x, sorted by x: ``[x, mean(y) + 1e-10, sem]``.
        ``sem`` is the population (ddof=0) standard deviation divided by
        ``sqrt(count)``, so a single sample gives an error of 0. The tiny
        ``1e-10`` offset is applied to every group; it keeps later
        ``err / mean`` divisions finite when a mean is exactly zero. An input
        with no rows gives an empty (0, 3) array.

    Raises
    ------
    ValueError
        If ``Data`` is not two-dimensional with at least two columns.
    """
    Data = numpy.asarray(Data, dtype=float)
    if Data.ndim != 2 or Data.shape[1] < 2:
        raise ValueError("Data must be a 2-D array with at least two columns")
    if Data.shape[0] == 0:
        return numpy.zeros((0, 3))

    # Sort by x so that equal x values form contiguous runs.
    order = numpy.argsort(Data[:, 0], kind="stable")
    x_sorted = Data[order, 0]
    y_sorted = Data[order, 1]

    # Distinct x values, plus the start index and length of each run.
    xs, starts, counts = numpy.unique(x_sorted, return_index=True,
                                      return_counts=True)

    means = numpy.add.reduceat(y_sorted, starts)/counts
    # Population variance (ddof=0), matching numpy.std's default.
    sq_dev = (y_sorted - numpy.repeat(means, counts))**2
    std = numpy.sqrt(numpy.add.reduceat(sq_dev, starts)/counts)
    err = std/numpy.sqrt(counts)

    return numpy.column_stack((xs, means + 1e-10, err))

def doublet(x,y0,x1,s1,A1,x2,s2,A2):
    """Return ``y0 * exp(G1(x) + G2(x))`` for two Gaussians in the exponent.

    Each Gaussian has centre ``x1``/``x2``, standard deviation ``s1``/``s2``
    and amplitude ``A1``/``A2``. Because the Gaussians sit in the exponent, a
    negative amplitude produces a dip below the baseline ``y0``.
    NOTE(review): presumably this models loss features on a background level
    ``y0``; confirm against the fitting code.
    """
    return y0*numpy.exp(Gaussian(x,x1,s1,A1)+Gaussian(x,x2,s2,A2))


def singlet(x,y0,x1,s1,A1):
    """Return ``y0 * exp(G(x))`` for a single Gaussian in the exponent.

    ``x1``, ``s1`` and ``A1`` are the centre, standard deviation and amplitude
    of the Gaussian. ``y0`` is the baseline level.
    """
    return y0*numpy.exp(Gaussian(x,x1,s1,A1))

# Fit results, one row per data file. Column 0 is the data file stem. For
# singlets, columns 1-4 are the singlet() parameters (y0, x1, s1, A1) and
# columns 5-8 their uncertainties. For doublets, columns 1-7 are the doublet()
# parameters (y0, x1, s1, A1, x2, s2, A2) followed by uncertainties.
# os.path.join keeps the paths valid on every platform (the previous "\\"
# concatenation only worked on Windows).
singlet_pars = numpy.genfromtxt(os.path.join(cwd, "singlet.csv"), delimiter=',',
                                skip_header=1, dtype=None, encoding=None)
doublet_pars = numpy.genfromtxt(os.path.join(cwd, "doublet.csv"), delimiter=',',
                                skip_header=1, dtype=None, encoding=None)

# Reference frequency (Hz) subtracted from the laser frequency before doubling.
offset = 281.6346e12

centres=[]
err_centres = []

for i, x in enumerate(singlet_pars):
    curve = [x[j] for j in range(1, 5)]
    err = [x[j] for j in range(5, 9)]
    fname = x[0]
    data = numpy.genfromtxt(os.path.join(cwd, fname + ".csv"), delimiter=',',
                            skip_header=1)
    data = data[:, 1:]
    if i == 0:
        # The first singlet's fitted centre becomes the zero of the plotted
        # x axis for every dataset (singlets and doublets).
        offset2 = curve[1]
    # Scale x by 1e12 (presumably THz -> Hz; TODO confirm the file units),
    # subtract the reference offset, then double it.
    data[:, 0] = 2*(1e12*data[:, 0] - offset)
    data_avg = GetAveragedData(data)

    # Normalise by the fitted baseline y0 and combine the relative errors of
    # the averaged point and of y0 in quadrature.
    errorbars = (data_avg[:,1]/curve[0])*numpy.sqrt((data_avg[:,2]/data_avg[:,1])**2+(err[0]/curve[0])**2)
    freqs = numpy.linspace(numpy.amin(data_avg[:,0]),numpy.amax(data_avg[:,0]),500)

    # The third character of the file stem picks the axis. '0' goes on ax0
    # in blue (legend "N=0"); anything else goes on the twin axis ax1 in red
    # (legend "N=1").
    if fname[2] == '0':
        ax = ax0
        colour = jqc_plot.colours['blue']
        highlight = jqc_plot.colours['grayblue']
    else:
        ax = ax1
        colour = jqc_plot.colours['red']
        highlight = jqc_plot.colours['reddish']

    ax.errorbar(1e-9*(data_avg[:,0]-offset2),data_avg[:,1]/curve[0],yerr=errorbars,
                color=colour,capsize=3.5,fmt='o',zorder=1.5)
    # Fitted curve, evaluated on a fine grid across the measured range.
    ax.plot(1e-9*(freqs-offset2),singlet(freqs,*curve)/curve[0],color=highlight,
            zorder=1.2)
    centres.append(curve[1])
    err_centres.append(err[1])

# Plot each doublet dataset with its two-Gaussian fit, and record both
# fitted centres (x1 then x2) and their uncertainties.
for x in doublet_pars:
    fname = x[0]
    # os.path.join keeps this portable (the "\\" concatenation was Windows-only).
    data = numpy.genfromtxt(os.path.join(cwd, fname + ".csv"), delimiter=',',
                            skip_header=1)
    data = data[:, 1:]
    # Same x conversion as the singlet loop: scale, subtract offset, double.
    data[:, 0] = 2*(1e12*data[:, 0] - offset)
    data_avg = GetAveragedData(data)
    # doublet() parameters (y0, x1, s1, A1, x2, s2, A2) and their uncertainties.
    # NOTE(review): range(8, 14) yields only six uncertainties for seven
    # parameters (none for A2). This is harmless here because only err[0],
    # err[1] and err[4] are used, but confirm the doublet.csv layout.
    curve = [x[j] for j in range(1, 8)]
    err = [x[j] for j in range(8, 14)]

    # Normalise by the fitted baseline y0 and combine the relative errors in
    # quadrature.
    errorbars = (data_avg[:,1]/curve[0])*numpy.sqrt((data_avg[:,2]/data_avg[:,1])**2+(err[0]/curve[0])**2)
    freqs = numpy.linspace(numpy.amin(data_avg[:,0]),
                           numpy.amax(data_avg[:,0]),
                           500)

    # '0' as the third stem character -> ax0/blue (N=0); otherwise ax1/red (N=1).
    if fname[2] == '0':
        ax = ax0
        colour = jqc_plot.colours['blue']
        highlight = jqc_plot.colours['grayblue']
    else:
        ax = ax1
        colour = jqc_plot.colours['red']
        highlight = jqc_plot.colours['reddish']

    # offset2 is the first singlet's fitted centre, set by the singlet loop.
    ax.errorbar(1e-9*(data_avg[:,0]-offset2),data_avg[:,1]/curve[0],yerr=errorbars,
                color=colour,capsize=3.5,fmt='o',zorder=1.5)
    # Fitted curve
    ax.plot(1e-9*(freqs-offset2),doublet(freqs,*curve)/curve[0],color=highlight,
            zorder=1.2)

    centres.append(curve[1])
    centres.append(curve[4])
    err_centres.append(err[1])
    err_centres.append(err[4])

# --- Panel (a): axis ranges, legend and labels ------------------------------
ax0.set_ylim(-0.2,1.4)
ax0.set_xlim(-1,5.5)

# The twin (top) axis shows the same span shifted by a fixed amount in GHz.
# NOTE(review): presumably the N=1 reference offset; the fit section below
# adds 980e6 Hz to the N=1 centres instead.
n1_axis_shift = 0.980231
xmin,xmax = ax0.get_xlim()
ax1.set_xlim(xmin-n1_axis_shift,xmax-n1_axis_shift)

# Empty errorbar artists that exist only to create the legend entries.
legend_entries = (
    (jqc_plot.colours['blue'], "$N=0$"),
    (jqc_plot.colours['red'], "$N=1$"),
)
for entry_colour, entry_label in legend_entries:
    ax0.errorbar([],[],yerr=[],color=entry_colour,capsize=3.5,fmt='o',
                 label=entry_label)

ax0.legend(frameon=False,loc='lower left',bbox_to_anchor=(0.6,0.1))

# Bottom axis in blue, top (twin) axis in red, both with the same label text.
freq_axis_label = "2 $\\times$ [Laser Frequency - $f_0$] (GHz)"
for target_ax, target_colour in ((ax0, jqc_plot.colours['blue']),
                                 (ax1, jqc_plot.colours['red'])):
    target_ax.set_xlabel(freq_axis_label,color=target_colour)
    target_ax.tick_params(axis='x',which='both', colors=target_colour)

ax0.set_ylabel("Fraction of Molecules Remaining")

ax1.spines['top'].set_color(jqc_plot.colours['red'])
ax1.spines['bottom'].set_color(jqc_plot.colours['blue'])

ax0.text(0.01,0.05,"(a)",transform=ax0.transAxes,fontsize=20)

# --- Panel (b): potential energy curves -------------------------------------
ax_potentials = fig.add_subplot(grid[1])

# Length (THz) of each excitation arrow drawn on this panel (~9394 cm^-1);
# presumably the photon frequency -- confirm.
energy = 281.63471
waveno_to_GHz = 29.9792458 # speed of light in cm per ns (cm^-1 -> GHz); not used below

red = jqc_plot.colours['red']
reddish = jqc_plot.colours['reddish']
lineblue = jqc_plot.colours['grayblue']
blue = jqc_plot.colours['blue']
green = jqc_plot.colours['green']
greenish = jqc_plot.colours['greenish']
sand = jqc_plot.colours['sand']

# Curves are read from Potentials/NONSO next to this script. cwd is already
# defined at the top of the file, and os.path.join keeps the path portable.
potential_dir = os.path.join(cwd, "Potentials", "NONSO")

# Colour of the first energy column for the files handled uniformly.
first_curve_colour = {"3P": greenish, "3S": green, "1P": reddish}

# Order matters: "1S" must be first because it defines `minimum` (the energy
# zero for every curve) and `min_loc` (used for the arrows).
for stem in ["1S", "3P", "3S", "1P", "D"]:
    data = numpy.genfromtxt(os.path.join(potential_dir, stem + ".csv"),
                            delimiter=',', skip_header=1)
    # Column 0: separation in bohr -> nm. Other columns: energies in atomic
    # units -> THz.
    data[:, 0] = data[:, 0]*bohr*1e9
    data[:, 1:] = data[:, 1:]*autoTHz
    r = data[:, 0]

    if stem == "1S":
        minimum = numpy.amin(data[:, 1])
        ax_potentials.plot(r, data[:, 1] - minimum, color=red, zorder=2)
        # Separation at the (first) minimum of the reference curve.
        min_loc = r[numpy.argmin(data[:, 1])]
        ax_potentials.plot(r, data[:, 2] - minimum, color=reddish, zorder=1.5)
        print("A1Sigma", numpy.amin(data[:, 2]) - minimum)
        ax_potentials.plot(r, data[:, 3] - minimum, color=sand, zorder=1.)
        ax_potentials.plot(r, data[:, 4] - minimum, color=sand, zorder=1.)
        ax_potentials.plot(r, data[:, 5] - minimum, color=reddish, zorder=1.5)
        ax_potentials.plot(r, data[:, 6] - minimum, color=sand, zorder=1.)
    elif stem == "D":
        # NOTE(review): column 1 of D.csv is never drawn, unlike every other
        # file -- confirm whether that is intentional.
        ax_potentials.plot(r, data[:, 2:] - minimum, color=sand, zorder=1.)
    else:
        ax_potentials.plot(r, data[:, 1] - minimum,
                           color=first_curve_colour[stem], zorder=1.4)
        if stem == "1P":
            print("B1Pi", numpy.amin(data[:, 1]) - minimum)
        ax_potentials.plot(r, data[:, 2:] - minimum, color=sand, zorder=1.)

# Two stacked arrows of length `energy` at the reference curve's minimum,
# from 0 to energy and from energy to 2*energy.
arrow_props = dict(arrowstyle="->",color='k',lw=2.5,shrinkA=0,shrinkB=0)
arrow_spans = ((energy, 0), (2*energy, energy))
for arrow_top, arrow_bottom in arrow_spans:
    ax_potentials.annotate("",(min_loc,arrow_top),(min_loc,arrow_bottom),arrowprops=arrow_props)

# State labels placed in data coordinates (nm, THz).
state_labels = (
    (0.63, 50, "$X^1\\Sigma^+$", red),
    (0.75, 130, "$a^3\\Sigma^+$", green),
    (0.48, 375, "$B^1\\Pi$", reddish),
    (0.5, 255, "$b^3\\Pi$", greenish),
    (0.7, 320, "$A^1\\Sigma^+$", reddish),
    (0.6, 550, "$(5)^1\\Sigma^+$", reddish),
)
for label_x, label_y, label_text, label_colour in state_labels:
    ax_potentials.text(label_x,label_y,label_text,color=label_colour,
                       transform=ax_potentials.transData,fontsize=12)

ax_potentials.set_xlabel("Nuclear Separation (nm)")
ax_potentials.set_ylabel("Energy/$h$ (THz)")

ax_potentials.set_xlim(.29,1.0)
ax_potentials.set_ylim(-31,620)
ax_potentials.text(0.85,0.05,"(b)",transform=ax_potentials.transAxes,fontsize=20)

pyplot.tight_layout()
# Output files are named after this script.
fname = os.path.splitext(os.path.basename(__file__))[0]
ymin,ymax=ax0.get_ylim()


# Manually chosen stronger components, in energy order. The indices refer to
# the order in which the fit loops appended to `centres`, so they must be
# revisited whenever singlet.csv/doublet.csv change. Indices 5 and 8 get
# 980e6 Hz added (presumably the N=1 components, moved back from the shifted
# twin axis).
# NOTE(review): the twin axis is shifted by 0.980231 GHz, not 0.980 GHz --
# confirm which value is intended.
N0 = centres[0]
N1 = centres[5]+980e6
N2 = centres[4]
N3 = centres[8]+980e6

err0 = err_centres[0]
err1 = err_centres[5]
err2 = err_centres[4]
err3 = err_centres[8]

N=numpy.array([N0,N1,N2,N3])-N0
err = numpy.array([err0,err1,err2,err3])
print(N/2,err/2)


def model(N, B):
    """Return B*N*(N+1), fitted below to obtain the rotational constant B."""
    return B*N*(N+1)


# Fit the rotational constant B (initial guess 400e6, in the same doubled-Hz
# units as N).
curve,cov = optimise.curve_fit(model,[0,1,2,3],N,sigma=err,absolute_sigma=True,p0=400e6)
perr = numpy.sqrt(numpy.diag(cov))  # one-sigma parameter uncertainties

print(curve*1e-6,perr*1e-6)

# asarray makes the subtraction work whatever the scalar type of offset2.
print((numpy.asarray(centres)-offset2)/2,numpy.array(err_centres)/2)

Brot = curve[0]

# Dashed markers at the predicted positions of N' = 0..3. They are drawn on
# ax1 but in ax0's data coordinates (GHz relative to offset2).
for i in range(4):
    rot = 1e-9*(i*(i+1)*Brot)
    ax1.plot([rot,rot],[ymin,ymax],color='k',ls='--',zorder=3,transform=ax0.transData)
    ax1.text(rot+0.02,-0.12,"$N'=${:.0f}".format(i),transform=ax0.transData)

# os.path.join keeps the output paths portable (previously Windows-only "\\").
pyplot.savefig(os.path.join(cwd, "Output", fname + '.pdf'))
pyplot.savefig(os.path.join(cwd, "Output", fname + '.png'))

pyplot.show()
